{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CzXqLs5MUU-D"
   },
   "source": [
    "# Learning to read with TensorFlow and Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    " <td>\n",
    "    <a target=\"_blank\"        href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/community/en/dev-summit-2020/learning_to_read_with_tensorFlow_and_keras.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "    <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/community/en/dev-summit-2020/learning_to_read_with_tensorFlow_and_keras.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ES1UsjrjUZP0"
   },
   "source": [
    "## Install necessary packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YlUDLPRIhCe7"
   },
   "outputs": [],
   "source": [
    "!pip install tf-nightly\n",
    "!pip install tensorflow-addons\n",
    "!pip install keras-tuner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7i6tA1ouhOEA"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "# Only the last expression of a cell is displayed, so print the version\n",
    "# explicitly; the `dir(...)` listing remains the cell's displayed value.\n",
    "print(tf.__version__)\n",
    "dir(tfa.seq2seq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "C2fQ54fmUcdd"
   },
   "source": [
    "## Download the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "s5zOIb-HlavM"
   },
   "outputs": [],
   "source": [
    "!wget http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz\n",
    "!tar -xf CBTest.tgz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QTT8GUxZmNAS"
   },
   "outputs": [],
   "source": [
    "!ls CBTest/data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "IvHM6W5KUftY"
   },
   "source": [
    "## Preprocess the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-TMKerSvjsez"
   },
   "outputs": [],
   "source": [
    "lines = tf.data.TextLineDataset('CBTest/data/cbt_train.txt')\n",
    "for row in lines.take(3):\n",
    "  print(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K5RASRVIq6Zs"
   },
   "outputs": [],
   "source": [
    "# Drop the book-title marker lines. `tf.math.logical_not` is used instead\n",
    "# of Python's `not` so the predicate is an explicit tensor op inside the\n",
    "# `tf.data` pipeline rather than depending on AutoGraph rewriting the lambda.\n",
    "lines = lines.filter(\n",
    "    lambda x: tf.math.logical_not(\n",
    "        tf.strings.regex_full_match(x, \"_BOOK_TITLE_.*\")))\n",
    "\n",
    "# Replace every punctuation character with a single space.\n",
    "punctuation = r'[!\"#$%&()\\*\\+,-\\./:;<=>?@\\[\\\\\\]^_`{|}~\\']'\n",
    "lines = lines.map(\n",
    "    lambda x: tf.strings.regex_replace(x, punctuation, ' '))\n",
    "\n",
    "for row in lines.take(3):\n",
    "  print(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "AtDjQmchufuN"
   },
   "outputs": [],
   "source": [
    "# Split the lines on `spaces`.\n",
    "words = lines.map(tf.strings.split)\n",
    "\n",
    "# Batch them into 11 words per batch. This way\n",
    "# the first 10 words are the training data and the\n",
    "# 11th word is the prediction word. `drop_remainder=True`\n",
    "# discards a trailing batch with fewer than 11 words, which\n",
    "# would otherwise produce an example with a truncated context.\n",
    "wordsets = words.unbatch().batch(11, drop_remainder=True)\n",
    "\n",
    "for row in wordsets.take(3):\n",
    "  print(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NebCV-gQzE7o"
   },
   "outputs": [],
   "source": [
    "def get_example_label(row):\n",
    "  \"\"\"Splits one window of words into a context string and its label.\n",
    "\n",
    "  Args:\n",
    "    row: 1-D string tensor holding one batch of words from `wordsets`.\n",
    "\n",
    "  Returns:\n",
    "    A pair `(example, label)`. `example` is every word except the last,\n",
    "    joined with single spaces and given a leading axis of size one.\n",
    "    `label` is the slice containing only the last word.\n",
    "  \"\"\"\n",
    "  context_words, label = row[:-1], row[-1:]\n",
    "  joined = tf.strings.reduce_join(context_words, separator=' ')\n",
    "  return tf.expand_dims(joined, axis=0), label\n",
    "\n",
    "data = wordsets.map(get_example_label)\n",
    "data = data.shuffle(1000)\n",
    "for row in data.take(3):\n",
    "  print(row)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "U7SZY4QyVEE1"
   },
   "source": [
    "Use the `TextVectorization` layer to tokenize the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MEo9fAx8lAMF"
   },
   "outputs": [],
   "source": [
    "max_features = 5000  # Maximum vocab size.\n",
    "\n",
    "vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(\n",
    "  max_tokens=max_features,\n",
    "  output_sequence_length=10)\n",
    "\n",
    "vectorize_layer.adapt(lines.batch(64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-J0ueEWWpiho"
   },
   "outputs": [],
   "source": [
    "vectorize_layer.get_vocabulary()[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eLZwUeDUprhu"
   },
   "outputs": [],
   "source": [
    "vectorize_layer.get_vocabulary()[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "n1utBbnqrjHk"
   },
   "outputs": [],
   "source": [
    "for batch in data.batch(3).take(1):\n",
    "  print(batch[0])\n",
    "  print(vectorize_layer(batch[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LGEmbAPqVKCS"
   },
   "source": [
    "## Create the Encoder-Decoder based model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "o8VyzIOHVXhr"
   },
   "outputs": [],
   "source": [
    "class EncoderDecoder(tf.keras.Model):\n",
    "  \"\"\"LSTM encoder-decoder that predicts the word following a text window.\n",
    "\n",
    "  Raw strings are tokenized by an internal `TextVectorization` layer (which\n",
    "  must be adapted before use), encoded by an LSTM, attended over, and\n",
    "  decoded for a single step into logits over the vocabulary.\n",
    "\n",
    "  Args:\n",
    "    max_features: Vocabulary size given to the vectorizer; also the width\n",
    "      of the output logits.\n",
    "    embedding_dims: Size of the encoder and decoder embeddings.\n",
    "    rnn_units: Number of units in the encoder LSTM and decoder LSTM cell.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, max_features=5000, embedding_dims=200, rnn_units=1024):\n",
    "    super().__init__()\n",
    "    self.max_features = max_features\n",
    "    self.vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(\n",
    "        max_tokens=max_features,\n",
    "        output_sequence_length=10)\n",
    "    self.encoder_embedding = tf.keras.layers.Embedding(\n",
    "        max_features + 1, embedding_dims)\n",
    "    # Return the per-step outputs as well as the final state so that the\n",
    "    # attention layer has an actual sequence to attend over.\n",
    "    self.lstm_layer = tf.keras.layers.LSTM(\n",
    "        rnn_units, return_sequences=True, return_state=True)\n",
    "\n",
    "    self.decoder_embedding = tf.keras.layers.Embedding(\n",
    "        max_features + 1, embedding_dims)\n",
    "    sampler = tfa.seq2seq.sampler.TrainingSampler()\n",
    "    decoder_cell = tf.keras.layers.LSTMCell(rnn_units)\n",
    "    projection_layer = tf.keras.layers.Dense(max_features)\n",
    "    self.decoder = tfa.seq2seq.BasicDecoder(\n",
    "        decoder_cell, sampler, output_layer=projection_layer)\n",
    "\n",
    "    self.attention = tf.keras.layers.Attention()\n",
    "\n",
    "  def _next_word_logits(self, x):\n",
    "    \"\"\"Runs the encoder, attention and one decoder step on token ids.\n",
    "\n",
    "    Args:\n",
    "      x: Integer tensor of token ids produced by `self.vectorize_layer`.\n",
    "\n",
    "    Returns:\n",
    "      The decoder's `rnn_output`: logits with a length-1 time axis.\n",
    "    \"\"\"\n",
    "    embedded_inputs = self.encoder_embedding(x)\n",
    "    encoder_outputs, state_h, state_c = self.lstm_layer(embedded_inputs)\n",
    "\n",
    "    # The final hidden state is the query and the encoder's per-step outputs\n",
    "    # are the values. Both carry an explicit time axis so attention is\n",
    "    # computed within each example; 2-D inputs would instead be scored\n",
    "    # against each other across the batch.\n",
    "    query = tf.expand_dims(state_h, axis=1)\n",
    "    attn_output = self.attention([query, encoder_outputs])\n",
    "\n",
    "    # The single decoder step is always fed the embedding of token id 0.\n",
    "    targets = self.decoder_embedding(tf.zeros_like(x[:, -1:]))\n",
    "    concat_output = tf.concat([targets, attn_output], axis=-1)\n",
    "    outputs, _, _ = self.decoder(\n",
    "        concat_output, initial_state=[state_h, state_c])\n",
    "    return outputs.rnn_output\n",
    "\n",
    "  def train_step(self, data):\n",
    "    \"\"\"Runs one optimization step on a batch of `(text, next_word)` strings.\n",
    "\n",
    "    Args:\n",
    "      data: Pair whose first element is the batch of context strings and\n",
    "        whose second element is the batch of label words.\n",
    "\n",
    "    Returns:\n",
    "      Dict mapping each metric name to its current result.\n",
    "    \"\"\"\n",
    "    x, y = data[0], data[1]\n",
    "    x = self.vectorize_layer(x)\n",
    "    # The vectorize layer pads, but we only need the first val for labels\n",
    "    y = self.vectorize_layer(y)[:, 0:1]\n",
    "    y_one_hot = tf.one_hot(y, self.max_features)\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "      y_pred = self._next_word_logits(x)\n",
    "      loss = self.compiled_loss(\n",
    "          y_one_hot,\n",
    "          y_pred,\n",
    "          regularization_losses=self.losses)\n",
    "\n",
    "    trainable_variables = self.trainable_variables\n",
    "    gradients = tape.gradient(loss, trainable_variables)\n",
    "    self.optimizer.apply_gradients(zip(gradients, trainable_variables))\n",
    "\n",
    "    self.compiled_metrics.update_state(y_one_hot, y_pred)\n",
    "    return {m.name: m.result() for m in self.metrics}\n",
    "\n",
    "  def predict_step(self, data, select_from_top_n=1):\n",
    "    \"\"\"Predicts the next word for each input string.\n",
    "\n",
    "    Args:\n",
    "      data: Batch of input strings, or an `(inputs, labels)` tuple whose\n",
    "        labels are ignored.\n",
    "      select_from_top_n: For each input, the word is drawn uniformly at\n",
    "        random from this many highest-scoring candidates.\n",
    "\n",
    "    Returns:\n",
    "      List with one vocabulary entry per input.\n",
    "    \"\"\"\n",
    "    x = data\n",
    "    if isinstance(x, tuple) and len(x) == 2:\n",
    "      x = x[0]\n",
    "    x = self.vectorize_layer(x)\n",
    "\n",
    "    y_pred = tf.squeeze(self._next_word_logits(x), axis=1)\n",
    "    # Rank candidates after skipping the first two logit columns.\n",
    "    # NOTE(review): the ranked indices are relative to that slice and are\n",
    "    # used as-is to index `get_vocabulary()`. This presumes the returned\n",
    "    # vocabulary omits the two skipped entries - confirm for the TF\n",
    "    # version in use.\n",
    "    top_n = tf.argsort(\n",
    "        y_pred[:, 2:], axis=1, direction='DESCENDING')[: ,:select_from_top_n]\n",
    "    chosen_indices = tf.random.uniform(\n",
    "        [top_n.shape[0], 1], minval=0, maxval=select_from_top_n,\n",
    "        dtype=tf.dtypes.int32)\n",
    "    counter = tf.expand_dims(tf.range(0, top_n.shape[0]), axis=1)\n",
    "    indices = tf.concat([counter, chosen_indices], axis=1)\n",
    "    choices = tf.gather_nd(top_n, indices)\n",
    "    # Fetch the vocabulary once instead of once per predicted word.\n",
    "    vocabulary = self.vectorize_layer.get_vocabulary()\n",
    "    words = [vocabulary[i] for i in choices]\n",
    "    return words\n",
    "\n",
    "  def predict(self, starting_string, num_steps=50, select_from_top_n=1):\n",
    "    \"\"\"Extends `starting_string` one predicted word at a time.\n",
    "\n",
    "    This shadows the inherited `tf.keras.Model.predict` with a\n",
    "    text-generation entry point.\n",
    "\n",
    "    Args:\n",
    "      starting_string: Seed text; it is split on single spaces.\n",
    "      num_steps: Number of words to append.\n",
    "      select_from_top_n: Forwarded to `predict_step`.\n",
    "\n",
    "    Returns:\n",
    "      The seed words followed by the generated words, joined by spaces.\n",
    "    \"\"\"\n",
    "    s = tf.compat.as_bytes(starting_string).split(b' ')\n",
    "    for _ in range(num_steps):\n",
    "      # Only the most recent ten words are fed back to the model.\n",
    "      windowed = [b' '.join(s[-10:])]\n",
    "      pred = self.predict_step([windowed], select_from_top_n=select_from_top_n)\n",
    "      s.append(pred[0])\n",
    "    return b' '.join(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "k0zSE_N76u2E"
   },
   "outputs": [],
   "source": [
    "model = EncoderDecoder()\n",
    "model.compile(\n",
    "    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), \n",
    "    optimizer='adam', \n",
    "    metrics=['accuracy'])\n",
    "model.vectorize_layer.adapt(lines.batch(256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bYUiLYeq9ZqK"
   },
   "outputs": [],
   "source": [
    "model.fit(data.batch(256), epochs=30, callbacks=[tf.keras.callbacks.ModelCheckpoint('text_gen_ckpt')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9Xmj6MXAGsAI"
   },
   "outputs": [],
   "source": [
    "model.fit(data.batch(256), epochs=10, callbacks=[tf.keras.callbacks.ModelCheckpoint('text_gen_ckpt')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UIPq5T58MHn1"
   },
   "outputs": [],
   "source": [
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "66HekuGbhKnH"
   },
   "outputs": [],
   "source": [
    "model.load_weights('text_gen_ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "o2vweexwQmdu"
   },
   "outputs": [],
   "source": [
    "print(model.predict('The mouse and the rabbit went in together'))\n",
    "\n",
    "print(model.predict('Once upon a time there was a Queen named Darling'))\n",
    "\n",
    "print(model.predict('In a city far from here the teacup shook upon the table'))\n",
    "\n",
    "print(model.predict('It was a strange and quiet theater and the people watched from home'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7sYGSWnsVVNn"
   },
   "source": [
    "## Use `keras-tuner` to tune the hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xUp_sz27d_TM"
   },
   "outputs": [],
   "source": [
    "import kerastuner as kt\n",
    "\n",
    "def build_model(hp):\n",
    "  \"\"\"Builds and compiles an `EncoderDecoder` for one tuner trial.\n",
    "\n",
    "  Args:\n",
    "    hp: Hyperparameter container supplied by the tuner; `units` and\n",
    "      `learning_rate` are sampled from it.\n",
    "\n",
    "  Returns:\n",
    "    A compiled `EncoderDecoder` whose vectorization layer has been adapted\n",
    "    on `lines`.\n",
    "  \"\"\"\n",
    "  rnn_units = hp.Int('units', min_value=256, max_value=1200, step=256)\n",
    "  learning_rate = hp.Choice('learning_rate', values=[1e-3, 1e-4, 3e-4])\n",
    "\n",
    "  candidate = EncoderDecoder(rnn_units=rnn_units)\n",
    "  candidate.compile(\n",
    "      optimizer=tf.keras.optimizers.Adam(learning_rate),\n",
    "      loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n",
    "      metrics=['accuracy'])\n",
    "\n",
    "  candidate.vectorize_layer.adapt(lines.batch(256))\n",
    "  return candidate\n",
    "\n",
    "tuner = kt.tuners.RandomSearch(\n",
    "    build_model,\n",
    "    objective='accuracy',\n",
    "    max_trials=15,\n",
    "    executions_per_trial=1,\n",
    "    directory='my_dir',\n",
    "    project_name='text_generation')\n",
    "\n",
    "tuner.search(\n",
    "    data.batch(256), \n",
    "    epochs=10, \n",
    "    callbacks=[tf.keras.callbacks.ModelCheckpoint('text_gen_ckpt')])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Learning to read with TensorFlow and Keras",
   "private_outputs": true,
   "provenance": [
    {
     "file_id": "1_4JR0tUwMNB2eseR0HpaHykbtiAoF1A5",
     "timestamp": 1583790881369
    },
    {
     "file_id": "1BeifgsKnAdCnJ2WMtZ8yQxrfsG1pcMLB",
     "timestamp": 1582991967988
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
